{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d5c5aa32",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.0.1'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ace7f4e",
   "metadata": {},
   "source": [
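    "# Computing gradients with PyTorch\n",
    "PyTorch's Autograd module implements the automatic differentiation behind backpropagation in deep learning. Autograd automatically provides derivatives for every operation on tensors (the Tensor class), so we do not have to derive them by hand.\n",
    "\n",
    "In versions before 0.4, PyTorch used the Variable class to compute all gradients automatically. Variable had three main attributes:\n",
    "\n",
    "- data: holds the Tensor wrapped by the Variable.\n",
    "- grad: holds the gradient of data. grad was itself a Variable rather than a Tensor, with the same shape as data.\n",
    "- grad_fn: points to a Function object that computes the gradients of the inputs during backpropagation.\n",
    "\n",
    "Since 0.4, Variable has been merged into the Tensor class, and the automatic differentiation that used to come from wrapping tensors in Variable is now built into Tensor. For backward compatibility you can still write Variable(tensor), but the call essentially does nothing.\n",
    "\n",
    "New code should therefore work with Tensor directly, because the official documentation marks Variable as deprecated.\n",
    "\n",
    "Tensor now supports autograd by itself. To enable it, just set .requires_grad=True.\n",
    "\n",
    "The grad and grad_fn attributes of Variable have likewise been merged into the Tensor class."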
   ]
  },
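  {
   "cell_type": "markdown",
   "id": "a7e1c0d1",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of the points above. The tensor t is introduced only for illustration. You can also turn on requires_grad after a tensor is created by calling the in-place method requires_grad_(), and you can read the former Variable attributes grad and grad_fn directly from the tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# A tensor created without gradient tracking\n",
    "t = torch.ones(3)\n",
    "print(t.requires_grad)    # False\n",
    "\n",
    "# Enable tracking in place (same effect as setting t.requires_grad = True)\n",
    "t.requires_grad_()\n",
    "print(t.requires_grad)    # True\n",
    "\n",
    "# grad and grad_fn now live on Tensor itself. No backward() has run yet,\n",
    "# and t was created by the user, so both are None\n",
    "print(t.grad, t.grad_fn)  # None None"
   ]
  },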
  {
   "cell_type": "markdown",
   "id": "ac108638",
   "metadata": {},
   "source": [
    "## Autograd\n",
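    "When you create a tensor, set its requires_grad flag to True to tell PyTorch that the tensor needs automatic differentiation. PyTorch then records every operation performed on it and computes the gradients automatically."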
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e8f78920",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.5466, 0.4557, 0.8138, 0.7021, 0.3837],\n",
       "        [0.3876, 0.5698, 0.0593, 0.0470, 0.1703],\n",
       "        [0.6517, 0.8158, 0.8245, 0.9850, 0.4737],\n",
       "        [0.6552, 0.1595, 0.8717, 0.7866, 0.3089],\n",
       "        [0.5084, 0.7214, 0.0033, 0.5265, 0.7606]], requires_grad=True)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(5, 5, requires_grad=True)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e948322d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.7107, 0.4346, 0.3709, 0.4259, 0.7132],\n",
       "        [0.9172, 0.2533, 0.6769, 0.2052, 0.2006],\n",
       "        [0.5626, 0.9939, 0.7434, 0.7676, 0.5683],\n",
       "        [0.5391, 0.3264, 0.1772, 0.2826, 0.6309],\n",
       "        [0.6488, 0.8614, 0.1256, 0.7032, 0.4141]], requires_grad=True)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = torch.rand(5, 5, requires_grad=True)\n",
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "abc818ee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(26.4426, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z = torch.sum(x + y)\n",
    "z"
   ]
  },
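  {
   "cell_type": "markdown",
   "id": "a7e1c0d2",
   "metadata": {},
   "source": [
    "The following cell checks which tensors carry a grad_fn. The small tensors a, b and c are made up for this sketch and built the same way as x, y and z. The hand-made tensors have no grad_fn, while the summed result records a SumBackward0 node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Two hand-made leaf tensors and one result tensor (names chosen for this sketch)\n",
    "a = torch.rand(2, 2, requires_grad=True)\n",
    "b = torch.rand(2, 2, requires_grad=True)\n",
    "c = torch.sum(a + b)\n",
    "\n",
    "print(a.grad_fn, b.grad_fn)       # None None: both were created by the user\n",
    "print(type(c.grad_fn).__name__)   # SumBackward0"
   ]
  },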
  {
   "cell_type": "markdown",
   "id": "5f1cb8ff",
   "metadata": {},
   "source": [
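    "After an operation, the resulting tensor's grad_fn refers to a Function object: the Function that created this Tensor. Tensors and Functions are linked into a directed acyclic graph that records and encodes the complete computation history. Every tensor has a .grad_fn attribute. If the user created the tensor manually, its grad_fn is None.\n",
    "\n",
    "## Simple automatic differentiation\n",
    "\n",
    "Next we call the backpropagation function backward() to compute the gradients."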
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "495fdc89",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.]]) tensor([[1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.]])\n"
     ]
    }
   ],
   "source": [
    "z.backward()\n",
    "print(x.grad, y.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6613aca8",
   "metadata": {},
   "source": [
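    "If the Tensor is a scalar (a tensor holding a single element), backward() needs no arguments. If it has more elements, you must pass a gradient argument: a tensor whose shape matches. The call z.backward() above is shorthand for z.backward(torch.tensor(1.)). In practice, backward() is usually called on a scalar loss, for example a classification model's average loss over a batch.\n",
    "\n",
    "## More complex automatic differentiation"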
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "03fe1f7a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.7150, 0.5451, 0.7948, 0.9113, 0.6791],\n",
       "        [0.5611, 0.5545, 0.1474, 0.7216, 0.2016],\n",
       "        [1.3377, 0.6702, 0.7980, 1.2841, 1.0018],\n",
       "        [1.1722, 0.4004, 0.5327, 0.6950, 0.0217],\n",
       "        [0.5750, 1.1271, 0.0825, 0.5925, 1.0580]], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(5, 5, requires_grad=True)\n",
    "y = torch.rand(5, 5, requires_grad=True)\n",
    "z = x**2 + y**3\n",
    "z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6e048703",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.7130, 0.9340, 0.5964, 1.5559, 0.4297],\n",
      "        [1.4818, 0.3395, 0.7494, 1.4430, 0.8968],\n",
      "        [1.7802, 1.5119, 1.6779, 1.6170, 1.9863],\n",
      "        [1.0223, 0.7314, 0.1781, 1.6490, 0.2947],\n",
      "        [1.3259, 1.0751, 0.1247, 0.9223, 1.9027]])\n"
     ]
    }
   ],
   "source": [
    "# The result is not a scalar, so we pass a tensor of the same shape as the gradient argument.\n",
    "# ones_like builds an all-ones tensor shaped like z.\n",
    "z.backward(torch.ones_like(z))\n",
    "print(x.grad)"
   ]
  },
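  {
   "cell_type": "markdown",
   "id": "a7e1c0d4",
   "metadata": {},
   "source": [
    "As a sanity check: z = x**2 + y**3 is computed elementwise and the gradient argument is all ones, so x.grad should equal 2*x and y.grad should equal 3*y**2. The sketch below repeats the computation on fresh tensors p, q and r (names chosen for this example). It then uses torch.allclose to compare the results against these analytic derivatives."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "p = torch.rand(5, 5, requires_grad=True)\n",
    "q = torch.rand(5, 5, requires_grad=True)\n",
    "r = p**2 + q**3\n",
    "\n",
    "# All-ones gradient argument: every entry of r contributes with weight 1\n",
    "r.backward(torch.ones_like(r))\n",
    "\n",
    "# Analytic derivatives: d(p^2)/dp = 2p and d(q^3)/dq = 3q^2\n",
    "print(torch.allclose(p.grad, 2 * p))      # True\n",
    "print(torch.allclose(q.grad, 3 * q**2))   # True"
   ]
  },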
  {
   "cell_type": "markdown",
   "id": "1a3f9441",
   "metadata": {},
   "source": [
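    "The with torch.no_grad() context manager temporarily disables automatic differentiation, even for tensors that have requires_grad=True. It is often used when computing accuracy on a test set, for example:"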
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4c741263",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    print((x + y*2).requires_grad)"
   ]
  },
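  {
   "cell_type": "markdown",
   "id": "a7e1c0d5",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of the test-set use case. The model, inputs and labels are hypothetical stand-ins (a single nn.Linear layer and random data), not part of the original example. The point is that outputs computed inside torch.no_grad() do not require grad, even though the model's weights do."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "model = nn.Linear(4, 3)               # hypothetical toy classifier: 4 features, 3 classes\n",
    "inputs = torch.randn(8, 4)            # hypothetical test batch\n",
    "labels = torch.randint(0, 3, (8,))    # hypothetical labels\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():                 # no graph is recorded inside this block\n",
    "    logits = model(inputs)\n",
    "    preds = logits.argmax(dim=1)\n",
    "    accuracy = (preds == labels).float().mean().item()\n",
    "\n",
    "print(logits.requires_grad)           # False, even though the model weights require grad\n",
    "print(f'accuracy: {accuracy:.2f}')"
   ]
  },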
  {
   "cell_type": "markdown",
   "id": "e0216af2",
   "metadata": {},
   "source": [
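    "Inside a torch.no_grad() block, operations are not tracked in the computation history. Because this history is not stored, memory usage drops and computation runs slightly faster.\n",
    "\n",
    "## How Autograd works\n",
    "To see how PyTorch's automatic differentiation works, let's look at PyTorch's internals. The core of PyTorch's Tensor (its base class TensorBase) is implemented in C++, but we can still inspect these objects' attributes and state from Python. Python's built-in dir() returns the list of an object's attributes and methods. z is a Tensor, so let's see which members it has."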
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8fc66eb4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['H',\n",
       " 'T',\n",
       " '__abs__',\n",
       " '__add__',\n",
       " '__and__',\n",
       " '__array__',\n",
       " '__array_priority__',\n",
       " '__array_wrap__',\n",
       " '__bool__',\n",
       " '__class__',\n",
       " '__complex__',\n",
       " '__contains__',\n",
       " '__deepcopy__',\n",
       " '__delattr__',\n",
       " '__delitem__',\n",
       " '__dict__',\n",
       " '__dir__',\n",
       " '__div__',\n",
       " '__dlpack__',\n",
       " '__dlpack_device__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__float__',\n",
       " '__floordiv__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattribute__',\n",
       " '__getitem__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__iadd__',\n",
       " '__iand__',\n",
       " '__idiv__',\n",
       " '__ifloordiv__',\n",
       " '__ilshift__',\n",
       " '__imod__',\n",
       " '__imul__',\n",
       " '__index__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__int__',\n",
       " '__invert__',\n",
       " '__ior__',\n",
       " '__ipow__',\n",
       " '__irshift__',\n",
       " '__isub__',\n",
       " '__iter__',\n",
       " '__itruediv__',\n",
       " '__ixor__',\n",
       " '__le__',\n",
       " '__len__',\n",
       " '__long__',\n",
       " '__lshift__',\n",
       " '__lt__',\n",
       " '__matmul__',\n",
       " '__mod__',\n",
       " '__module__',\n",
       " '__mul__',\n",
       " '__ne__',\n",
       " '__neg__',\n",
       " '__new__',\n",
       " '__nonzero__',\n",
       " '__or__',\n",
       " '__pos__',\n",
       " '__pow__',\n",
       " '__radd__',\n",
       " '__rand__',\n",
       " '__rdiv__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__reversed__',\n",
       " '__rfloordiv__',\n",
       " '__rlshift__',\n",
       " '__rmatmul__',\n",
       " '__rmod__',\n",
       " '__rmul__',\n",
       " '__ror__',\n",
       " '__rpow__',\n",
       " '__rrshift__',\n",
       " '__rshift__',\n",
       " '__rsub__',\n",
       " '__rtruediv__',\n",
       " '__rxor__',\n",
       " '__setattr__',\n",
       " '__setitem__',\n",
       " '__setstate__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__sub__',\n",
       " '__subclasshook__',\n",
       " '__torch_dispatch__',\n",
       " '__torch_function__',\n",
       " '__truediv__',\n",
       " '__weakref__',\n",
       " '__xor__',\n",
       " '_addmm_activation',\n",
       " '_autocast_to_full_precision',\n",
       " '_autocast_to_reduced_precision',\n",
       " '_backward_hooks',\n",
       " '_base',\n",
       " '_cdata',\n",
       " '_coalesced_',\n",
       " '_conj',\n",
       " '_conj_physical',\n",
       " '_dimI',\n",
       " '_dimV',\n",
       " '_fix_weakref',\n",
       " '_grad',\n",
       " '_grad_fn',\n",
       " '_has_symbolic_sizes_strides',\n",
       " '_indices',\n",
       " '_is_all_true',\n",
       " '_is_any_true',\n",
       " '_is_view',\n",
       " '_is_zerotensor',\n",
       " '_make_subclass',\n",
       " '_make_wrapper_subclass',\n",
       " '_neg_view',\n",
       " '_nested_tensor_size',\n",
       " '_nested_tensor_strides',\n",
       " '_nnz',\n",
       " '_python_dispatch',\n",
       " '_reduce_ex_internal',\n",
       " '_to_dense',\n",
       " '_typed_storage',\n",
       " '_update_names',\n",
       " '_values',\n",
       " '_version',\n",
       " '_view_func',\n",
       " 'abs',\n",
       " 'abs_',\n",
       " 'absolute',\n",
       " 'absolute_',\n",
       " 'acos',\n",
       " 'acos_',\n",
       " 'acosh',\n",
       " 'acosh_',\n",
       " 'add',\n",
       " 'add_',\n",
       " 'addbmm',\n",
       " 'addbmm_',\n",
       " 'addcdiv',\n",
       " 'addcdiv_',\n",
       " 'addcmul',\n",
       " 'addcmul_',\n",
       " 'addmm',\n",
       " 'addmm_',\n",
       " 'addmv',\n",
       " 'addmv_',\n",
       " 'addr',\n",
       " 'addr_',\n",
       " 'adjoint',\n",
       " 'align_as',\n",
       " 'align_to',\n",
       " 'all',\n",
       " 'allclose',\n",
       " 'amax',\n",
       " 'amin',\n",
       " 'aminmax',\n",
       " 'angle',\n",
       " 'any',\n",
       " 'apply_',\n",
       " 'arccos',\n",
       " 'arccos_',\n",
       " 'arccosh',\n",
       " 'arccosh_',\n",
       " 'arcsin',\n",
       " 'arcsin_',\n",
       " 'arcsinh',\n",
       " 'arcsinh_',\n",
       " 'arctan',\n",
       " 'arctan2',\n",
       " 'arctan2_',\n",
       " 'arctan_',\n",
       " 'arctanh',\n",
       " 'arctanh_',\n",
       " 'argmax',\n",
       " 'argmin',\n",
       " 'argsort',\n",
       " 'argwhere',\n",
       " 'as_strided',\n",
       " 'as_strided_',\n",
       " 'as_strided_scatter',\n",
       " 'as_subclass',\n",
       " 'asin',\n",
       " 'asin_',\n",
       " 'asinh',\n",
       " 'asinh_',\n",
       " 'atan',\n",
       " 'atan2',\n",
       " 'atan2_',\n",
       " 'atan_',\n",
       " 'atanh',\n",
       " 'atanh_',\n",
       " 'backward',\n",
       " 'baddbmm',\n",
       " 'baddbmm_',\n",
       " 'bernoulli',\n",
       " 'bernoulli_',\n",
       " 'bfloat16',\n",
       " 'bincount',\n",
       " 'bitwise_and',\n",
       " 'bitwise_and_',\n",
       " 'bitwise_left_shift',\n",
       " 'bitwise_left_shift_',\n",
       " 'bitwise_not',\n",
       " 'bitwise_not_',\n",
       " 'bitwise_or',\n",
       " 'bitwise_or_',\n",
       " 'bitwise_right_shift',\n",
       " 'bitwise_right_shift_',\n",
       " 'bitwise_xor',\n",
       " 'bitwise_xor_',\n",
       " 'bmm',\n",
       " 'bool',\n",
       " 'broadcast_to',\n",
       " 'byte',\n",
       " 'cauchy_',\n",
       " 'ccol_indices',\n",
       " 'cdouble',\n",
       " 'ceil',\n",
       " 'ceil_',\n",
       " 'cfloat',\n",
       " 'chalf',\n",
       " 'char',\n",
       " 'cholesky',\n",
       " 'cholesky_inverse',\n",
       " 'cholesky_solve',\n",
       " 'chunk',\n",
       " 'clamp',\n",
       " 'clamp_',\n",
       " 'clamp_max',\n",
       " 'clamp_max_',\n",
       " 'clamp_min',\n",
       " 'clamp_min_',\n",
       " 'clip',\n",
       " 'clip_',\n",
       " 'clone',\n",
       " 'coalesce',\n",
       " 'col_indices',\n",
       " 'conj',\n",
       " 'conj_physical',\n",
       " 'conj_physical_',\n",
       " 'contiguous',\n",
       " 'copy_',\n",
       " 'copysign',\n",
       " 'copysign_',\n",
       " 'corrcoef',\n",
       " 'cos',\n",
       " 'cos_',\n",
       " 'cosh',\n",
       " 'cosh_',\n",
       " 'count_nonzero',\n",
       " 'cov',\n",
       " 'cpu',\n",
       " 'cross',\n",
       " 'crow_indices',\n",
       " 'cuda',\n",
       " 'cummax',\n",
       " 'cummin',\n",
       " 'cumprod',\n",
       " 'cumprod_',\n",
       " 'cumsum',\n",
       " 'cumsum_',\n",
       " 'data',\n",
       " 'data_ptr',\n",
       " 'deg2rad',\n",
       " 'deg2rad_',\n",
       " 'dense_dim',\n",
       " 'dequantize',\n",
       " 'det',\n",
       " 'detach',\n",
       " 'detach_',\n",
       " 'device',\n",
       " 'diag',\n",
       " 'diag_embed',\n",
       " 'diagflat',\n",
       " 'diagonal',\n",
       " 'diagonal_scatter',\n",
       " 'diff',\n",
       " 'digamma',\n",
       " 'digamma_',\n",
       " 'dim',\n",
       " 'dist',\n",
       " 'div',\n",
       " 'div_',\n",
       " 'divide',\n",
       " 'divide_',\n",
       " 'dot',\n",
       " 'double',\n",
       " 'dsplit',\n",
       " 'dtype',\n",
       " 'eig',\n",
       " 'element_size',\n",
       " 'eq',\n",
       " 'eq_',\n",
       " 'equal',\n",
       " 'erf',\n",
       " 'erf_',\n",
       " 'erfc',\n",
       " 'erfc_',\n",
       " 'erfinv',\n",
       " 'erfinv_',\n",
       " 'exp',\n",
       " 'exp2',\n",
       " 'exp2_',\n",
       " 'exp_',\n",
       " 'expand',\n",
       " 'expand_as',\n",
       " 'expm1',\n",
       " 'expm1_',\n",
       " 'exponential_',\n",
       " 'fill_',\n",
       " 'fill_diagonal_',\n",
       " 'fix',\n",
       " 'fix_',\n",
       " 'flatten',\n",
       " 'flip',\n",
       " 'fliplr',\n",
       " 'flipud',\n",
       " 'float',\n",
       " 'float_power',\n",
       " 'float_power_',\n",
       " 'floor',\n",
       " 'floor_',\n",
       " 'floor_divide',\n",
       " 'floor_divide_',\n",
       " 'fmax',\n",
       " 'fmin',\n",
       " 'fmod',\n",
       " 'fmod_',\n",
       " 'frac',\n",
       " 'frac_',\n",
       " 'frexp',\n",
       " 'gather',\n",
       " 'gcd',\n",
       " 'gcd_',\n",
       " 'ge',\n",
       " 'ge_',\n",
       " 'geometric_',\n",
       " 'geqrf',\n",
       " 'ger',\n",
       " 'get_device',\n",
       " 'grad',\n",
       " 'grad_fn',\n",
       " 'greater',\n",
       " 'greater_',\n",
       " 'greater_equal',\n",
       " 'greater_equal_',\n",
       " 'gt',\n",
       " 'gt_',\n",
       " 'half',\n",
       " 'hardshrink',\n",
       " 'has_names',\n",
       " 'heaviside',\n",
       " 'heaviside_',\n",
       " 'histc',\n",
       " 'histogram',\n",
       " 'hsplit',\n",
       " 'hypot',\n",
       " 'hypot_',\n",
       " 'i0',\n",
       " 'i0_',\n",
       " 'igamma',\n",
       " 'igamma_',\n",
       " 'igammac',\n",
       " 'igammac_',\n",
       " 'imag',\n",
       " 'index_add',\n",
       " 'index_add_',\n",
       " 'index_copy',\n",
       " 'index_copy_',\n",
       " 'index_fill',\n",
       " 'index_fill_',\n",
       " 'index_put',\n",
       " 'index_put_',\n",
       " 'index_reduce',\n",
       " 'index_reduce_',\n",
       " 'index_select',\n",
       " 'indices',\n",
       " 'inner',\n",
       " 'int',\n",
       " 'int_repr',\n",
       " 'inverse',\n",
       " 'ipu',\n",
       " 'is_coalesced',\n",
       " 'is_complex',\n",
       " 'is_conj',\n",
       " 'is_contiguous',\n",
       " 'is_cpu',\n",
       " 'is_cuda',\n",
       " 'is_distributed',\n",
       " 'is_floating_point',\n",
       " 'is_inference',\n",
       " 'is_ipu',\n",
       " 'is_leaf',\n",
       " 'is_meta',\n",
       " 'is_mkldnn',\n",
       " 'is_mps',\n",
       " 'is_neg',\n",
       " 'is_nested',\n",
       " 'is_nonzero',\n",
       " 'is_ort',\n",
       " 'is_pinned',\n",
       " 'is_quantized',\n",
       " 'is_same_size',\n",
       " 'is_set_to',\n",
       " 'is_shared',\n",
       " 'is_signed',\n",
       " 'is_sparse',\n",
       " 'is_sparse_csr',\n",
       " 'is_vulkan',\n",
       " 'is_xpu',\n",
       " 'isclose',\n",
       " 'isfinite',\n",
       " 'isinf',\n",
       " 'isnan',\n",
       " 'isneginf',\n",
       " 'isposinf',\n",
       " 'isreal',\n",
       " 'istft',\n",
       " 'item',\n",
       " 'kron',\n",
       " 'kthvalue',\n",
       " 'layout',\n",
       " 'lcm',\n",
       " 'lcm_',\n",
       " 'ldexp',\n",
       " 'ldexp_',\n",
       " 'le',\n",
       " 'le_',\n",
       " 'lerp',\n",
       " 'lerp_',\n",
       " 'less',\n",
       " 'less_',\n",
       " 'less_equal',\n",
       " 'less_equal_',\n",
       " 'lgamma',\n",
       " 'lgamma_',\n",
       " 'log',\n",
       " 'log10',\n",
       " 'log10_',\n",
       " 'log1p',\n",
       " 'log1p_',\n",
       " 'log2',\n",
       " 'log2_',\n",
       " 'log_',\n",
       " 'log_normal_',\n",
       " 'log_softmax',\n",
       " 'logaddexp',\n",
       " 'logaddexp2',\n",
       " 'logcumsumexp',\n",
       " 'logdet',\n",
       " 'logical_and',\n",
       " 'logical_and_',\n",
       " 'logical_not',\n",
       " 'logical_not_',\n",
       " 'logical_or',\n",
       " 'logical_or_',\n",
       " 'logical_xor',\n",
       " 'logical_xor_',\n",
       " 'logit',\n",
       " 'logit_',\n",
       " 'logsumexp',\n",
       " 'long',\n",
       " 'lstsq',\n",
       " 'lt',\n",
       " 'lt_',\n",
       " 'lu',\n",
       " 'lu_solve',\n",
       " 'mH',\n",
       " 'mT',\n",
       " 'map2_',\n",
       " 'map_',\n",
       " 'masked_fill',\n",
       " 'masked_fill_',\n",
       " 'masked_scatter',\n",
       " 'masked_scatter_',\n",
       " 'masked_select',\n",
       " 'matmul',\n",
       " 'matrix_exp',\n",
       " 'matrix_power',\n",
       " 'max',\n",
       " 'maximum',\n",
       " 'mean',\n",
       " 'median',\n",
       " 'min',\n",
       " 'minimum',\n",
       " 'mm',\n",
       " 'mode',\n",
       " 'moveaxis',\n",
       " 'movedim',\n",
       " 'msort',\n",
       " 'mul',\n",
       " 'mul_',\n",
       " 'multinomial',\n",
       " 'multiply',\n",
       " 'multiply_',\n",
       " 'mv',\n",
       " 'mvlgamma',\n",
       " 'mvlgamma_',\n",
       " 'name',\n",
       " 'names',\n",
       " 'nan_to_num',\n",
       " 'nan_to_num_',\n",
       " 'nanmean',\n",
       " 'nanmedian',\n",
       " 'nanquantile',\n",
       " 'nansum',\n",
       " 'narrow',\n",
       " 'narrow_copy',\n",
       " 'ndim',\n",
       " 'ndimension',\n",
       " 'ne',\n",
       " 'ne_',\n",
       " 'neg',\n",
       " 'neg_',\n",
       " 'negative',\n",
       " 'negative_',\n",
       " 'nelement',\n",
       " 'new',\n",
       " 'new_empty',\n",
       " 'new_empty_strided',\n",
       " 'new_full',\n",
       " 'new_ones',\n",
       " 'new_tensor',\n",
       " 'new_zeros',\n",
       " 'nextafter',\n",
       " 'nextafter_',\n",
       " 'nonzero',\n",
       " 'norm',\n",
       " 'normal_',\n",
       " 'not_equal',\n",
       " 'not_equal_',\n",
       " 'numel',\n",
       " 'numpy',\n",
       " 'orgqr',\n",
       " 'ormqr',\n",
       " 'outer',\n",
       " 'output_nr',\n",
       " 'permute',\n",
       " 'pin_memory',\n",
       " 'pinverse',\n",
       " 'polygamma',\n",
       " 'polygamma_',\n",
       " 'positive',\n",
       " 'pow',\n",
       " 'pow_',\n",
       " 'prelu',\n",
       " 'prod',\n",
       " 'put',\n",
       " 'put_',\n",
       " 'q_per_channel_axis',\n",
       " 'q_per_channel_scales',\n",
       " 'q_per_channel_zero_points',\n",
       " 'q_scale',\n",
       " 'q_zero_point',\n",
       " 'qr',\n",
       " 'qscheme',\n",
       " 'quantile',\n",
       " 'rad2deg',\n",
       " 'rad2deg_',\n",
       " 'random_',\n",
       " 'ravel',\n",
       " 'real',\n",
       " 'reciprocal',\n",
       " 'reciprocal_',\n",
       " 'record_stream',\n",
       " 'refine_names',\n",
       " 'register_hook',\n",
       " 'reinforce',\n",
       " 'relu',\n",
       " 'relu_',\n",
       " 'remainder',\n",
       " 'remainder_',\n",
       " 'rename',\n",
       " 'rename_',\n",
       " 'renorm',\n",
       " 'renorm_',\n",
       " 'repeat',\n",
       " 'repeat_interleave',\n",
       " 'requires_grad',\n",
       " 'requires_grad_',\n",
       " 'reshape',\n",
       " 'reshape_as',\n",
       " 'resize',\n",
       " 'resize_',\n",
       " 'resize_as',\n",
       " 'resize_as_',\n",
       " 'resize_as_sparse_',\n",
       " 'resolve_conj',\n",
       " 'resolve_neg',\n",
       " 'retain_grad',\n",
       " 'retains_grad',\n",
       " 'roll',\n",
       " 'rot90',\n",
       " 'round',\n",
       " 'round_',\n",
       " 'row_indices',\n",
       " 'rsqrt',\n",
       " 'rsqrt_',\n",
       " 'scatter',\n",
       " 'scatter_',\n",
       " 'scatter_add',\n",
       " 'scatter_add_',\n",
       " 'scatter_reduce',\n",
       " 'scatter_reduce_',\n",
       " 'select',\n",
       " 'select_scatter',\n",
       " 'set_',\n",
       " 'sgn',\n",
       " 'sgn_',\n",
       " 'shape',\n",
       " 'share_memory_',\n",
       " 'short',\n",
       " 'sigmoid',\n",
       " 'sigmoid_',\n",
       " 'sign',\n",
       " 'sign_',\n",
       " 'signbit',\n",
       " 'sin',\n",
       " 'sin_',\n",
       " 'sinc',\n",
       " 'sinc_',\n",
       " 'sinh',\n",
       " 'sinh_',\n",
       " 'size',\n",
       " 'slice_scatter',\n",
       " 'slogdet',\n",
       " 'smm',\n",
       " 'softmax',\n",
       " 'solve',\n",
       " 'sort',\n",
       " 'sparse_dim',\n",
       " 'sparse_mask',\n",
       " 'sparse_resize_',\n",
       " 'sparse_resize_and_clear_',\n",
       " 'split',\n",
       " 'split_with_sizes',\n",
       " 'sqrt',\n",
       " 'sqrt_',\n",
       " 'square',\n",
       " 'square_',\n",
       " 'squeeze',\n",
       " 'squeeze_',\n",
       " 'sspaddmm',\n",
       " 'std',\n",
       " 'stft',\n",
       " 'storage',\n",
       " 'storage_offset',\n",
       " 'storage_type',\n",
       " 'stride',\n",
       " 'sub',\n",
       " 'sub_',\n",
       " 'subtract',\n",
       " 'subtract_',\n",
       " 'sum',\n",
       " 'sum_to_size',\n",
       " 'svd',\n",
       " 'swapaxes',\n",
       " 'swapaxes_',\n",
       " 'swapdims',\n",
       " 'swapdims_',\n",
       " 'symeig',\n",
       " 't',\n",
       " 't_',\n",
       " 'take',\n",
       " 'take_along_dim',\n",
       " 'tan',\n",
       " 'tan_',\n",
       " 'tanh',\n",
       " 'tanh_',\n",
       " 'tensor_split',\n",
       " 'tile',\n",
       " 'to',\n",
       " 'to_dense',\n",
       " 'to_mkldnn',\n",
       " 'to_padded_tensor',\n",
       " 'to_sparse',\n",
       " 'to_sparse_bsc',\n",
       " 'to_sparse_bsr',\n",
       " 'to_sparse_coo',\n",
       " 'to_sparse_csc',\n",
       " 'to_sparse_csr',\n",
       " 'tolist',\n",
       " 'topk',\n",
       " 'trace',\n",
       " 'transpose',\n",
       " 'transpose_',\n",
       " 'triangular_solve',\n",
       " 'tril',\n",
       " 'tril_',\n",
       " 'triu',\n",
       " 'triu_',\n",
       " 'true_divide',\n",
       " 'true_divide_',\n",
       " 'trunc',\n",
       " 'trunc_',\n",
       " 'type',\n",
       " 'type_as',\n",
       " 'unbind',\n",
       " 'unflatten',\n",
       " 'unfold',\n",
       " 'uniform_',\n",
       " 'unique',\n",
       " 'unique_consecutive',\n",
       " 'unsafe_chunk',\n",
       " 'unsafe_split',\n",
       " 'unsafe_split_with_sizes',\n",
       " 'unsqueeze',\n",
       " 'unsqueeze_',\n",
       " 'untyped_storage',\n",
       " 'values',\n",
       " 'var',\n",
       " 'vdot',\n",
       " 'view',\n",
       " 'view_as',\n",
       " 'vsplit',\n",
       " 'where',\n",
       " 'xlogy',\n",
       " 'xlogy_',\n",
       " 'xpu',\n",
       " 'zero_']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dir(z)"
   ]
  },
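  {
   "cell_type": "markdown",
   "id": "a7e1c0d6",
   "metadata": {},
   "source": [
    "The filtering described next can be written as a list comprehension over dir(). This sketch builds its own small result tensor res (a name chosen here) and keeps only the names that do not start with an underscore."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "u = torch.rand(2, requires_grad=True)\n",
    "res = u * 2                     # a non-leaf result tensor, like z above\n",
    "\n",
    "# Drop special methods (__xxx__) and private members (_xxx)\n",
    "public = [name for name in dir(res) if not name.startswith('_')]\n",
    "print(len(public))\n",
    "print([n for n in public if n in ('grad', 'grad_fn', 'is_leaf', 'requires_grad')])"
   ]
  },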
  {
   "cell_type": "markdown",
   "id": "40df44eb",
   "metadata": {},
   "source": [
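    "The list is long. Setting aside Python's special methods (names that begin and end with `__`) and private members (names that begin with `_`), let's look at a few of the more important attributes.\n",
    "\n",
    "The first is .is_leaf, which records whether a tensor is a leaf node and so tells us what kind of variable it is. The \"graph leaves\" and \"leaf variables\" mentioned in the official documentation refer to variables like x and y that are created by hand rather than computed. We call them created variables. Tensors like z, which are obtained as the result of a computation, are called result variables.\n",
    "\n",
    "You can check whether a variable is a created variable or a result variable with .is_leaf."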
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "875b1e61",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x.is_leaf=True\n",
      "z.is_leaf=False\n"
     ]
    }
   ],
   "source": [
    "print(\"x.is_leaf=\"+str(x.is_leaf))\n",
    "print(\"z.is_leaf=\"+str(z.is_leaf))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45fe7935",
   "metadata": {},
   "source": [
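    "x was created by hand rather than computed, so it is treated as a leaf node, i.e. a created variable. z was obtained from x and y through a series of operations, so it is not a leaf node but a result variable.\n",
    "\n",
    "Why does calling z.backward() update x.grad and y.grad? The answer lies in the .grad_fn attribute, which records these operations. .backward() is also implemented in C++, but we can explore it a little from Python.\n",
    "\n",
    "grad_fn: records and encodes the complete computation history"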
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f5fc82ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AddBackward0 at 0x11fd5e640>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z.grad_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ec1d07f",
   "metadata": {},
   "source": [
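    "grad_fn is an object of type AddBackward0. The AddBackward0 class is also written in C++, but its name tells us roughly what it is: the backward pass (Backward) of addition (Add). Let's see what it contains."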
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c70e8721",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__call__',\n",
       " '__class__',\n",
       " '__delattr__',\n",
       " '__dir__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattribute__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__le__',\n",
       " '__lt__',\n",
       " '__ne__',\n",
       " '__new__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__setattr__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__subclasshook__',\n",
       " '_register_hook_dict',\n",
       " '_saved_alpha',\n",
       " 'metadata',\n",
       " 'name',\n",
       " 'next_functions',\n",
       " 'register_hook',\n",
       " 'register_prehook',\n",
       " 'requires_grad']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dir(z.grad_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cba2be5b",
   "metadata": {},
   "source": [
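    "next_functions is the heart of grad_fn. It is a tuple of (Function, index) pairs that point to the backward nodes of this operation's inputs, and these links are what connect the graph together."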
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6add6053",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((<PowBackward0 at 0x11fd5ebb0>, 0), (<PowBackward0 at 0x11fd5e6a0>, 0))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z.grad_fn.next_functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6c836a3e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__call__',\n",
       " '__class__',\n",
       " '__delattr__',\n",
       " '__dir__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattribute__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__le__',\n",
       " '__lt__',\n",
       " '__ne__',\n",
       " '__new__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__setattr__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__subclasshook__',\n",
       " '_raw_saved_self',\n",
       " '_register_hook_dict',\n",
       " '_saved_exponent',\n",
       " '_saved_self',\n",
       " 'metadata',\n",
       " 'name',\n",
       " 'next_functions',\n",
       " 'register_hook',\n",
       " 'register_prehook',\n",
       " 'requires_grad']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xg = z.grad_fn.next_functions[0][0]\n",
    "dir(xg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8a10062f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AccumulateGrad at 0x11fd5e970>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_leaf=xg.next_functions[0][0]\n",
    "x_leaf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "101edd12",
   "metadata": {},
   "source": [
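    "In PyTorch's backward graph, the AccumulateGrad type represents a leaf node, i.e. a terminal node of the computation graph. The AccumulateGrad class has a .variable attribute that points to the leaf tensor."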
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "cbae89dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.3565, 0.4670, 0.2982, 0.7779, 0.2149],\n",
       "        [0.7409, 0.1698, 0.3747, 0.7215, 0.4484],\n",
       "        [0.8901, 0.7559, 0.8389, 0.8085, 0.9932],\n",
       "        [0.5111, 0.3657, 0.0891, 0.8245, 0.1474],\n",
       "        [0.6629, 0.5375, 0.0624, 0.4612, 0.9513]], requires_grad=True)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_leaf.variable"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e06c2216",
   "metadata": {},
   "source": [
    "This .variable attribute is exactly the variable x we created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e490e78b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "id of x_leaf.variable: 4461131952\n",
      "id of x: 4461131952\n"
     ]
    }
   ],
   "source": [
    "print('id of x_leaf.variable: ' + str(id(x_leaf.variable)))\n",
    "print('id of x: ' + str(id(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d71cfb59",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert x_leaf.variable is x"
   ]
  },
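  {
   "cell_type": "markdown",
   "id": "a7e1c0d7",
   "metadata": {},
   "source": [
    "A small recursive helper can show the whole backward graph at once instead of one level at a time. It follows next_functions until it reaches the AccumulateGrad leaves. The helper show_graph and the tensors a, b and out are made up for this sketch. out is built like z = x**2 + y**3 above, so the printed tree should show one AddBackward0 with two PowBackward0 children, each ending in an AccumulateGrad."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def show_graph(fn, depth=0, lines=None):\n",
    "    # Record this node's class name, indented by depth, then recurse into next_functions\n",
    "    if lines is None:\n",
    "        lines = []\n",
    "    if fn is None:  # inputs that do not require grad show up as None\n",
    "        return lines\n",
    "    lines.append('  ' * depth + type(fn).__name__)\n",
    "    for next_fn, _ in fn.next_functions:\n",
    "        show_graph(next_fn, depth + 1, lines)\n",
    "    return lines\n",
    "\n",
    "a = torch.rand(5, 5, requires_grad=True)\n",
    "b = torch.rand(5, 5, requires_grad=True)\n",
    "out = a**2 + b**3\n",
    "\n",
    "lines = show_graph(out.grad_fn)\n",
    "for line in lines:\n",
    "    print(line)"
   ]
  },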
  {
   "cell_type": "markdown",
   "id": "eb4dacfe",
   "metadata": {},
   "source": [
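    "The whole process is now clear:\n",
    "\n",
    "1. When we call z.backward(), it invokes z's grad_fn attribute to carry out the differentiation.\n",
    "2. This walks through grad_fn's next_functions and takes out each Function in turn (here the two PowBackward0 nodes, and below them the AccumulateGrad nodes) to compute the gradients. Conceptually this recursion continues until it reaches nodes of the leaf type.\n",
    "3. Each result is then stored in the grad attribute of the object referenced by the corresponding variable attribute (x and y).\n",
    "4. Differentiation ends, and the grad of every leaf node has been updated.\n",
    "\n",
    "So once z.backward() has finished, the grad values of x and y have been updated.\n",
    "\n",
    "## Extending Autograd\n",
    "To add new functionality to autograd, subclass the Function class (torch.autograd.Function). Autograd represents every operation it differentiates as a Function, which computes the result, computes the gradients, and records the operation in the history. The most important methods of a Function are forward() and backward(), which implement the forward pass and the backward pass.\n",
    "\n",
    "In older PyTorch versions, a custom Function defined three methods: `__init__` (optional, only needed when the operation takes extra parameters), forward() and backward().\n",
    "\n",
    "Current versions, including the 2.0.1 used in this notebook, work differently:\n",
    "\n",
    "- forward() and backward() are static methods whose first argument is a context object ctx.\n",
    "- Extra parameters are passed as additional arguments to forward().\n",
    "- You call the Function through its apply() method.\n",
    "\n",
    "forward(ctx, ...): the code for the forward computation.\n",
    "\n",
    "backward(ctx, ...): the code that computes gradients during backpropagation. It receives one gradient argument per forward() output, each being the gradient flowing back into that output. It returns one gradient per forward() input, or None for inputs that need no gradient."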
   ]
  }
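  ,
  {
   "cell_type": "markdown",
   "id": "a7e1c0d8",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of a new-style custom Function. It uses a hypothetical MulConstant operation that multiplies a tensor by a plain Python number c. The constant is passed as an extra forward() argument and saved on ctx. backward() returns one value per forward() input, with None for c because it is not a tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "class MulConstant(torch.autograd.Function):\n",
    "    # Hypothetical example op: out = inp * c, where c is a plain Python number\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, inp, c):\n",
    "        ctx.c = c                          # stash the constant for backward()\n",
    "        return inp * c\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        # One return value per forward() input: d(out)/d(inp) = c, and None for c\n",
    "        return grad_output * ctx.c, None\n",
    "\n",
    "inp = torch.rand(3, requires_grad=True)\n",
    "out = MulConstant.apply(inp, 3.0)          # custom Functions are called via apply()\n",
    "out.sum().backward()\n",
    "\n",
    "print(out.grad_fn)                         # the backward node created for MulConstant\n",
    "print(inp.grad)                            # tensor([3., 3., 3.])"
   ]
  }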
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "d2l",
   "language": "python",
   "name": "d2l"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
